{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bda8cfe0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pytorch_lightning as pl\n",
    "import torch\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "from pytorch_lightning.loggers import TensorBoardLogger\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from recommender.models import Recommender\n",
    "from recommender.data_processing import get_context, pad_list, map_column, MASK, PAD\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f687dde3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Locations of the MovieLens CSV files and of the trained checkpoint.\n",
    "# The defaults are the paths this notebook was originally run with; set the\n",
    "# MOVIELENS_DIR / RECOMMENDER_CKPT environment variables to run it elsewhere.\n",
    "movielens_dir = os.environ.get(\"MOVIELENS_DIR\", \"/home/jenazzad/ML_DATA/movielens/ml-25m\")\n",
    "data_csv_path = os.path.join(movielens_dir, \"ratings.csv\")\n",
    "movies_path = os.path.join(movielens_dir, \"movies.csv\")\n",
    "\n",
    "model_path = os.environ.get(\n",
    "    \"RECOMMENDER_CKPT\",\n",
    "    \"/home/jenazzad/PycharmProjects/recommender_transformer/recommender_models/recommender.ckpt\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cac77364",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the movie catalogue and the raw ratings from the configured paths.\n",
    "movies = pd.read_csv(movies_path)\n",
    "data = pd.read_csv(data_csv_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bb1a33bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Order the ratings chronologically (ascending timestamp).\n",
    "data = data.sort_values(by=\"timestamp\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0a65c910",
   "metadata": {},
   "outputs": [],
   "source": [
    "# map_column is a project helper; `mapping` is later indexed by raw movieId\n",
    "# to obtain the ids the model works with.\n",
    "data, mapping, inverse_mapping = map_column(\n",
    "    data,\n",
    "    col_name=\"movieId\",\n",
    ")\n",
    "\n",
    "# One group of ratings per user.\n",
    "grp_by_train = data.groupby(\"userId\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6923b93c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[49346, 144069, 123450, 148061, 15851, 158174, 23525, 90203, 88187, 105811]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at ten randomly chosen user ids (unseeded, so the sample differs per run).\n",
    "random.sample(\n",
    "    list(grp_by_train.groups.keys()),\n",
    "    k=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "80c7cc4e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rebuild the network: one id per mapped movie plus two extra ids\n",
    "# (presumably the PAD and MASK tokens -- confirm against the training script).\n",
    "model = Recommender(\n",
    "    vocab_size=len(mapping) + 2,\n",
    "    lr=1e-4,\n",
    "    dropout=0.3,\n",
    ")\n",
    "model.eval()  # switch to evaluation mode before running inference\n",
    "\n",
    "# map_location=\"cpu\" lets a checkpoint that was saved on a GPU load on a\n",
    "# CPU-only machine; predict() calls .numpy() on the output, which needs CPU tensors.\n",
    "# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.\n",
    "checkpoint = torch.load(model_path, map_location=\"cpu\")\n",
    "model.load_state_dict(checkpoint[\"state_dict\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "997385df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (title, model id) pairs for every catalogue movie that has a mapped id.\n",
    "_title_idx_pairs = [\n",
    "    (title, mapping[movie_id])\n",
    "    for title, movie_id in zip(movies.title.tolist(), movies.movieId.tolist())\n",
    "    if movie_id in mapping\n",
    "]\n",
    "\n",
    "# Title -> id. If two catalogue rows share a title, the later row wins (as before).\n",
    "movie_to_idx = {title: idx for title, idx in _title_idx_pairs}\n",
    "\n",
    "# Id -> title, built from the pairs rather than by inverting movie_to_idx so that\n",
    "# ids whose title is shared with another movie still resolve to a title.\n",
    "idx_to_movie = {idx: title for title, idx in _title_idx_pairs}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5b083cd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(list_movies, model, movie_to_idx, idx_to_movie, max_len=120, top_k=30):\n",
    "    \"\"\"Recommend titles that are likely to follow a watch history.\n",
    "\n",
    "    The titles are converted to ids, left-padded with PAD and terminated by a\n",
    "    MASK token; the scores the model emits at the MASK position rank all ids.\n",
    "\n",
    "    Args:\n",
    "        list_movies: Movie titles in viewing order. Every title must be a key\n",
    "            of `movie_to_idx`, otherwise a KeyError is raised.\n",
    "        model: Callable that takes a LongTensor of shape (1, max_len) and\n",
    "            returns a CPU tensor whose `[0, -1]` entry holds one score per id.\n",
    "        movie_to_idx: Mapping from title to model id.\n",
    "        idx_to_movie: Mapping from model id to title.\n",
    "        max_len: Length of the model input, MASK token included.\n",
    "        top_k: Maximum number of titles to return.\n",
    "\n",
    "    Returns:\n",
    "        Up to `top_k` titles, highest score first. Ids that are part of the\n",
    "        input (including PAD/MASK) or that have no title are skipped.\n",
    "    \"\"\"\n",
    "    history = [movie_to_idx[title] for title in list_movies]\n",
    "\n",
    "    # Keep only the most recent titles so history + MASK never exceeds max_len.\n",
    "    overflow = len(history) - (max_len - 1)\n",
    "    recent = history[overflow:] if overflow > 0 else history\n",
    "\n",
    "    ids = [PAD] * (max_len - 1 - len(recent)) + recent + [MASK]\n",
    "    src = torch.tensor(ids, dtype=torch.long).unsqueeze(0)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        prediction = model(src)\n",
    "\n",
    "    # Scores at the final (MASK) position; array index == candidate id.\n",
    "    masked_pred = prediction[0, -1].numpy()\n",
    "    ranked_ids = np.argsort(masked_pred).tolist()[::-1]\n",
    "\n",
    "    # Set lookup instead of scanning the id list once per vocabulary entry;\n",
    "    # titles dropped by the truncation above are still excluded from the output.\n",
    "    seen = set(ids) | set(history)\n",
    "\n",
    "    # Filter first, then stop at top_k, so ids without a title do not shrink\n",
    "    # the result below top_k.\n",
    "    recommendations = []\n",
    "    for idx in ranked_ids:\n",
    "        if len(recommendations) >= top_k:\n",
    "            break\n",
    "        if idx in seen or idx not in idx_to_movie:\n",
    "            continue\n",
    "        recommendations.append(idx_to_movie[idx])\n",
    "\n",
    "    return recommendations\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f0e44c1",
   "metadata": {},
   "source": [
    "### Scenario 1: Adventure/Fantasy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5dae87be",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Ice Age (2002)',\n",
       " \"Pirates of the Caribbean: Dead Man's Chest (2006)\",\n",
       " 'Avatar (2009)',\n",
       " 'Star Wars: Episode III - Revenge of the Sith (2005)',\n",
       " 'Shrek 2 (2004)',\n",
       " 'Ratatouille (2007)',\n",
       " 'Bruce Almighty (2003)',\n",
       " 'I, Robot (2004)',\n",
       " 'Last Samurai, The (2003)',\n",
       " 'Up (2009)',\n",
       " 'Matrix Revolutions, The (2003)',\n",
       " 'Men in Black II (a.k.a. MIIB) (a.k.a. MIB 2) (2002)',\n",
       " 'Iron Man (2008)',\n",
       " 'Spirited Away (Sen to Chihiro no kamikakushi) (2001)',\n",
       " '300 (2007)',\n",
       " 'Big Fish (2003)',\n",
       " \"Bridget Jones's Diary (2001)\",\n",
       " 'My Big Fat Greek Wedding (2002)',\n",
       " 'Pianist, The (2002)',\n",
       " 'Interstellar (2014)',\n",
       " 'Shaun of the Dead (2004)',\n",
       " 'Moulin Rouge (2001)',\n",
       " 'Juno (2007)',\n",
       " 'WALL·E (2008)',\n",
       " 'Casino Royale (2006)',\n",
       " 'School of Rock (2003)',\n",
       " '40-Year-Old Virgin, The (2005)',\n",
       " 'Harry Potter and the Order of the Phoenix (2007)',\n",
       " 'Bourne Supremacy, The (2004)',\n",
       " 'Miss Congeniality (2000)']"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Watch history fed to the model; titles must match movies.csv exactly.\n",
    "list_movies = [\n",
    "    \"Harry Potter and the Sorcerer's Stone (a.k.a. Harry Potter and the Philosopher's Stone) (2001)\",\n",
    "    \"Harry Potter and the Chamber of Secrets (2002)\",\n",
    "    \"Harry Potter and the Prisoner of Azkaban (2004)\",\n",
    "    \"Harry Potter and the Goblet of Fire (2005)\",\n",
    "]\n",
    "\n",
    "top_movie = predict(list_movies, model, movie_to_idx, idx_to_movie)\n",
    "top_movie"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "afb4b657",
   "metadata": {},
   "source": [
    "### Scenario 2: Action/Adventure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "96f0c5d3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Avengers: Infinity War - Part II (2019)',\n",
       " 'Deadpool 2 (2018)',\n",
       " 'Thor: Ragnarok (2017)',\n",
       " 'Spider-Man: Into the Spider-Verse (2018)',\n",
       " 'Captain Marvel (2018)',\n",
       " 'Incredibles 2 (2018)',\n",
       " 'Untitled Spider-Man Reboot (2017)',\n",
       " 'Ant-Man and the Wasp (2018)',\n",
       " 'Guardians of the Galaxy 2 (2017)',\n",
       " 'Iron Man 2 (2010)',\n",
       " 'Thor (2011)',\n",
       " 'Guardians of the Galaxy (2014)',\n",
       " 'Captain America: The First Avenger (2011)',\n",
       " 'X-Men Origins: Wolverine (2009)',\n",
       " \"Ocean's 8 (2018)\",\n",
       " 'Wonder Woman (2017)',\n",
       " 'Iron Man 3 (2013)',\n",
       " 'Pirates of the Caribbean: The Curse of the Black Pearl (2003)',\n",
       " 'Amazing Spider-Man, The (2012)',\n",
       " 'Aquaman (2018)',\n",
       " 'Dark Knight, The (2008)',\n",
       " 'Mission: Impossible - Fallout (2018)',\n",
       " 'Avengers: Age of Ultron (2015)',\n",
       " 'Jurassic World: Fallen Kingdom (2018)',\n",
       " 'Iron Man (2008)',\n",
       " 'Coco (2017)',\n",
       " 'Lord of the Rings: The Two Towers, The (2002)',\n",
       " 'Rogue One: A Star Wars Story (2016)',\n",
       " 'X-Men: The Last Stand (2006)',\n",
       " 'Venom (2018)']"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Watch history fed to the model; titles must match movies.csv exactly.\n",
    "list_movies = [\n",
    "    \"Black Panther (2017)\",\n",
    "    \"Avengers, The (2012)\",\n",
    "    \"Avengers: Infinity War - Part I (2018)\",\n",
    "    \"Logan (2017)\",\n",
    "    \"Spider-Man (2002)\",\n",
    "    \"Spider-Man 3 (2007)\",\n",
    "    \"Spider-Man: Far from Home (2019)\",\n",
    "]\n",
    "\n",
    "top_movie = predict(list_movies, model, movie_to_idx, idx_to_movie)\n",
    "top_movie"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a22ddfeb",
   "metadata": {},
   "source": [
    "### Scenario 3: Comedy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b3f01cbc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Home Alone (1990)',\n",
       " \"Bug's Life, A (1998)\",\n",
       " 'Toy Story 2 (1999)',\n",
       " 'Nightmare Before Christmas, The (1993)',\n",
       " 'Babe (1995)',\n",
       " 'Inside Out (2015)',\n",
       " 'Mask, The (1994)',\n",
       " 'Toy Story (1995)',\n",
       " 'Back to the Future (1985)',\n",
       " 'Back to the Future Part II (1989)',\n",
       " 'Simpsons Movie, The (2007)',\n",
       " 'Forrest Gump (1994)',\n",
       " 'Austin Powers: International Man of Mystery (1997)',\n",
       " 'Monty Python and the Holy Grail (1975)',\n",
       " 'Cars (2006)',\n",
       " 'Kung Fu Panda (2008)',\n",
       " 'Groundhog Day (1993)',\n",
       " 'American Pie (1999)',\n",
       " 'Men in Black (a.k.a. MIB) (1997)',\n",
       " 'Dumb & Dumber (Dumb and Dumber) (1994)',\n",
       " 'Back to the Future Part III (1990)',\n",
       " 'Big Hero 6 (2014)',\n",
       " 'Mrs. Doubtfire (1993)',\n",
       " 'Clueless (1995)',\n",
       " 'Bruce Almighty (2003)',\n",
       " 'Corpse Bride (2005)',\n",
       " 'Deadpool (2016)',\n",
       " 'Up (2009)',\n",
       " \"Ferris Bueller's Day Off (1986)\"]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Watch history fed to the model; titles must match movies.csv exactly.\n",
    "list_movies = [\n",
    "    \"Zootopia (2016)\",\n",
    "    \"Toy Story 3 (2010)\",\n",
    "    \"Toy Story 4 (2019)\",\n",
    "    \"Finding Nemo (2003)\",\n",
    "    \"Ratatouille (2007)\",\n",
    "    \"The Lego Movie (2014)\",\n",
    "    \"Ghostbusters (a.k.a. Ghost Busters) (1984)\",\n",
    "    \"Ace Ventura: When Nature Calls (1995)\",\n",
    "]\n",
    "\n",
    "top_movie = predict(list_movies, model, movie_to_idx, idx_to_movie)\n",
    "top_movie"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9434a0e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
